{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Uniform Distribution\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.distributions import Uniform\n",
    "\n",
    "# Set the parameters of the distribution\n",
    "a = torch.tensor([1.0], dtype=torch.float)\n",
    "b = torch.tensor([5.0], dtype=torch.float)\n",
    "\n",
    "# Instantiate the uniform distribution\n",
    "ufm_dist = Uniform(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Log Prob: -1.386\n",
      "Raw eval Log Prob: -1.386\n"
     ]
    }
   ],
   "source": [
    "# Instantiate single point test dataset\n",
    "X = torch.tensor([2.0], dtype=torch.float)\n",
    "\n",
    "# Function to evaluate log prob using math formula\n",
    "def raw_eval(X, a, b):\n",
    "    return torch.log(1 / (b - a))\n",
    "\n",
    "# Log-probability of X according to the PyTorch distribution object\n",
    "log_prob = ufm_dist.log_prob(X)\n",
    "# Log-probability of X according to the hand-written formula\n",
    "raw_eval_log_prob = raw_eval(X, a, b)\n",
    "\n",
    "print(\"Log Prob: {:.3f}\".format(log_prob[0]))\n",
    "print(\"Raw eval Log Prob: {:.3f}\".format(raw_eval_log_prob[0]))\n",
    "\n",
    "# Both evaluations must agree up to a small absolute tolerance\n",
    "assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed PyTorch's global random number generator so that re-running the\n",
    "# notebook draws the same samples and reproduces the printed statistics\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Number of samples to draw\n",
    "num_samples = 100_000\n",
    "\n",
    "# Draw num_samples samples from the distribution\n",
    "samples = ufm_dist.sample([num_samples])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample Mean: 2.9954352378845215\n",
      "Dist Mean: 3.000\n",
      "Sample Variance: 1.335\n",
      "Dist Variance: 1.333\n"
     ]
    }
   ],
   "source": [
    "# The mean obtained from the samples\n",
    "sample_mean = samples.mean()\n",
    "print(\"Sample Mean: {}\".format(sample_mean))\n",
    "\n",
    "# The mean of the distribution from Pytorch\n",
    "dist_mean = ufm_dist.mean\n",
    "print(\"Dist Mean: {:.3f}\".format(dist_mean[0]))\n",
    "\n",
    "# As expected, the two means approximately match\n",
    "assert torch.isclose(sample_mean, dist_mean, atol=0.2)\n",
    "\n",
    "# The variance obtained from the same samples that were used for the mean\n",
    "# (reusing them avoids drawing a second, unrelated batch)\n",
    "sample_var = samples.var()\n",
    "print(\"Sample Variance: {:.3f}\".format(sample_var))\n",
    "\n",
    "# The variance of the distribution from Pytorch\n",
    "dist_var = ufm_dist.variance\n",
    "print(\"Dist Variance: {:.3f}\".format(dist_var[0]))\n",
    "\n",
    "# As expected, the two variances approximately match\n",
    "assert torch.isclose(sample_var, dist_var, atol=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Visualization\n",
    "In this section, we allow the user to set different values for the parameters $a$ and $b$ of the univariate uniform distribution and visualise the resulting density. Observe that the area under the curve always remains 1, because the density of a probability distribution integrates to 1.\n",
    "\n",
    "Note: In order to run this section, please download the notebook. Interactive snippets do not work online. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09489b94a29e43ef96c159c95275ed96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to  previous…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7313e1aacafc4a28ac9a8b2be96e2413",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=25.0, description='a', max=40.0, min=10.0, step=5.0), FloatSlider(valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interact, interact_manual\n",
    "\n",
    "# A single figure/axes pair is created once and redrawn on every slider change\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "\n",
    "@interact\n",
    "def plot_uniform_dist(a=(10, 40, 5.0),\n",
    "                      b=(50, 80, 5.0)):\n",
    "    \"\"\"Redraw the density of Uniform(a, b) on the shared axes `ax`.\n",
    "\n",
    "    The tuple defaults are (min, max, step) specifications that `interact`\n",
    "    turns into one slider per parameter.\n",
    "\n",
    "    Args:\n",
    "        a: lower bound of the uniform distribution.\n",
    "        b: upper bound of the uniform distribution.\n",
    "    \"\"\"\n",
    "    x = np.linspace(0, 90, 1000)\n",
    "\n",
    "    # The grid x extends beyond [a, b). Argument validation is switched off\n",
    "    # because PyTorch releases that validate by default raise a ValueError in\n",
    "    # log_prob for points outside the support instead of returning -inf.\n",
    "    dist = Uniform(torch.tensor([a], dtype=torch.float),\n",
    "                   torch.tensor([b], dtype=torch.float),\n",
    "                   validate_args=False)\n",
    "    # exp(-inf) is 0, so the plotted density is zero outside the support\n",
    "    prob = dist.log_prob(torch.from_numpy(x).float()).exp()\n",
    "\n",
    "    # Reset the axes; clearing also drops the title, labels and limits\n",
    "    ax.clear()\n",
    "    ax.set_title(\"Univariate Uniform Distribution\")\n",
    "    ax.set_ylabel(\"P(X)\")\n",
    "    ax.set_xlabel(\"X\")\n",
    "    ax.set_ylim(0, 0.12)\n",
    "\n",
    "    ax.plot(x, prob.numpy(), 'green')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
